{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/finetuning/openai_fine_tuning_functions.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine Tuning with Function Calling\n",
    "\n",
    "In this notebook, we walk through how to fine-tune gpt-3.5-turbo with function calls. The primary use case here is structured data extraction. Our main focus is distilling GPT-4 outputs to help improve gpt-3.5-turbo's function-calling capabilities.\n",
    "\n",
    "We will walk through two examples, from simple to advanced:\n",
    "1. Fine-tuning on some toy messages/structured outputs logged through our OpenAI Pydantic Program object.\n",
    "2. Fine-tuning on context-augmented queries/structured outputs over an entire document corpus, so that the fine-tuned model can be used in a RAG system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-finetuning\n",
    "%pip install llama-index-llms-openai\n",
    "%pip install llama-index-finetuning-callbacks\n",
    "%pip install llama-index-readers-file\n",
    "%pip install llama-index-program-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tuning Using GPT-4 Pydantic Programs\n",
    "\n",
    "In this section we show how to log inputs/outputs through our low-level Pydantic Program module. We then use the logged dataset to fine-tune an LLM."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining Pydantic Model + Program\n",
    "\n",
    "Here, we define the GPT-4 powered function calling program that will generate structured outputs into a Pydantic object (an Album)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.program.openai import OpenAIPydanticProgram\n",
    "from pydantic import BaseModel\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.finetuning.callbacks import OpenAIFineTuningHandler\n",
    "from llama_index.core.callbacks import CallbackManager\n",
    "from typing import List\n",
    "\n",
    "\n",
    "class Song(BaseModel):\n",
    "    \"\"\"Data model for a song.\"\"\"\n",
    "\n",
    "    title: str\n",
    "    length_seconds: int\n",
    "\n",
    "\n",
    "class Album(BaseModel):\n",
    "    \"\"\"Data model for an album.\"\"\"\n",
    "\n",
    "    name: str\n",
    "    artist: str\n",
    "    songs: List[Song]\n",
    "\n",
    "\n",
    "finetuning_handler = OpenAIFineTuningHandler()\n",
    "callback_manager = CallbackManager([finetuning_handler])\n",
    "\n",
    "llm = OpenAI(model=\"gpt-4\", callback_manager=callback_manager)\n",
    "\n",
    "\n",
    "prompt_template_str = \"\"\"\\\n",
    "Generate an example album, with an artist and a list of songs. \\\n",
    "Using the movie {movie_name} as inspiration.\\\n",
    "\"\"\"\n",
    "program = OpenAIPydanticProgram.from_defaults(\n",
    "    output_cls=Album,\n",
    "    prompt_template_str=prompt_template_str,\n",
    "    llm=llm,\n",
    "    verbose=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Log Inputs/Outputs\n",
    "\n",
    "We define some sample movie names as inputs and log the outputs through the function calling program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: OpenAI fine-tuning requires at least 10 training examples, so we use >= 10 movies\n",
    "movie_names = [\n",
    "    \"The Shining\",\n",
    "    \"The Departed\",\n",
    "    \"Titanic\",\n",
    "    \"Goodfellas\",\n",
    "    \"Pretty Woman\",\n",
    "    \"Home Alone\",\n",
    "    \"Caged Fury\",\n",
    "    \"Edward Scissorhands\",\n",
    "    \"Total Recall\",\n",
    "    \"Ghost\",\n",
    "    \"Tremors\",\n",
    "    \"RoboCop\",\n",
    "    \"Rocky V\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004143953323364258,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 25,
       "postfix": null,
       "prefix": "",
       "rate": null,
       "total": 13,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0c6e3e3e2da545d1a5bb23e93d867444",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/13 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"name\": \"The Shining\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Main Title\", \"length_seconds\": 180}, {\"title\": \"Opening Credits\", \"length_seconds\": 120}, {\"title\": \"The Overlook Hotel\", \"length_seconds\": 240}, {\"title\": \"Redrum\", \"length_seconds\": 150}, {\"title\": \"Here's Johnny!\", \"length_seconds\": 200}]}\n",
      "{\"name\": \"The Departed Soundtrack\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Gimme Shelter\", \"length_seconds\": 272}, {\"title\": \"Comfortably Numb\", \"length_seconds\": 383}, {\"title\": \"I'm Shipping Up to Boston\", \"length_seconds\": 166}, {\"title\": \"Sweet Dreams (Are Made of This)\", \"length_seconds\": 216}, {\"title\": \"I'm Shipping Up to Boston (Instrumental)\", \"length_seconds\": 166}, {\"title\": \"The Departed Tango\", \"length_seconds\": 123}, {\"title\": \"Thief's Theme\", \"length_seconds\": 201}, {\"title\": \"Well Well Well\", \"length_seconds\": 126}, {\"title\": \"Comfortably Numb (Live)\", \"length_seconds\": 383}, {\"title\": \"Sail On, Sailor\", \"length_seconds\": 181}]}\n",
      "{\"name\": \"Titanic Soundtrack\", \"artist\": \"James Horner\", \"songs\": [{\"title\": \"My Heart Will Go On\", \"length_seconds\": 273}, {\"title\": \"Rose\", \"length_seconds\": 120}, {\"title\": \"Hymn to the Sea\", \"length_seconds\": 365}, {\"title\": \"Southampton\", \"length_seconds\": 180}, {\"title\": \"Take Her to Sea, Mr. Murdoch\", \"length_seconds\": 150}]}\n",
      "{\"name\": \"Goodfellas Soundtrack\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Rags to Riches\", \"length_seconds\": 180}, {\"title\": \"Gimme Shelter\", \"length_seconds\": 270}, {\"title\": \"Layla\", \"length_seconds\": 270}, {\"title\": \"Jump into the Fire\", \"length_seconds\": 240}, {\"title\": \"Atlantis\", \"length_seconds\": 180}, {\"title\": \"Beyond the Sea\", \"length_seconds\": 180}, {\"title\": \"Sunshine of Your Love\", \"length_seconds\": 240}, {\"title\": \"Mannish Boy\", \"length_seconds\": 240}, {\"title\": \"Layla (Piano Exit)\", \"length_seconds\": 120}]}\n",
      "{\"name\": \"Pretty Woman Soundtrack\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Oh, Pretty Woman\", \"length_seconds\": 178}, {\"title\": \"King of Wishful Thinking\", \"length_seconds\": 253}, {\"title\": \"It Must Have Been Love\", \"length_seconds\": 250}, {\"title\": \"Show Me Your Soul\", \"length_seconds\": 285}, {\"title\": \"No Explanation\", \"length_seconds\": 244}]}\n",
      "{\"name\": \"Home Alone Soundtrack\", \"artist\": \"John Williams\", \"songs\": [{\"title\": \"Somewhere in My Memory\", \"length_seconds\": 180}, {\"title\": \"Holiday Flight\", \"length_seconds\": 120}, {\"title\": \"The House\", \"length_seconds\": 150}, {\"title\": \"Star of Bethlehem\", \"length_seconds\": 135}, {\"title\": \"Setting the Trap\", \"length_seconds\": 165}, {\"title\": \"The Attack on the House\", \"length_seconds\": 200}, {\"title\": \"Mom Returns and Finale\", \"length_seconds\": 240}]}\n",
      "{\"name\": \"Caged Fury\", \"artist\": \"The Fury Band\", \"songs\": [{\"title\": \"Caged Fury\", \"length_seconds\": 240}, {\"title\": \"Prison Break\", \"length_seconds\": 180}, {\"title\": \"Behind Bars\", \"length_seconds\": 210}, {\"title\": \"Escape Plan\", \"length_seconds\": 195}, {\"title\": \"Fight for Freedom\", \"length_seconds\": 220}]}\n",
      "{\"name\": \"Edward Scissorhands Soundtrack\", \"artist\": \"Danny Elfman\", \"songs\": [{\"title\": \"Introduction\", \"length_seconds\": 120}, {\"title\": \"Ice Dance\", \"length_seconds\": 180}, {\"title\": \"Edwardo the Barber\", \"length_seconds\": 150}, {\"title\": \"The Grand Finale\", \"length_seconds\": 240}]}\n",
      "{\"name\": \"Total Recall\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Recall\", \"length_seconds\": 240}, {\"title\": \"Mars\", \"length_seconds\": 180}, {\"title\": \"Memory\", \"length_seconds\": 210}, {\"title\": \"Rebellion\", \"length_seconds\": 300}, {\"title\": \"Escape\", \"length_seconds\": 270}]}\n",
      "{\"name\": \"Ghost\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Unchained Melody\", \"length_seconds\": 218}, {\"title\": \"Oh My Love\", \"length_seconds\": 156}, {\"title\": \"Ditto's Theme\", \"length_seconds\": 92}, {\"title\": \"Love Inside\", \"length_seconds\": 180}, {\"title\": \"Ghostly Encounter\", \"length_seconds\": 120}]}\n",
      "{\"name\": \"Tremors Soundtrack\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Main Theme\", \"length_seconds\": 180}, {\"title\": \"Graboids Attack\", \"length_seconds\": 240}, {\"title\": \"Val and Earl's Theme\", \"length_seconds\": 200}, {\"title\": \"Burt's Arsenal\", \"length_seconds\": 220}, {\"title\": \"Nest of the Graboids\", \"length_seconds\": 190}]}\n",
      "{\"name\": \"RoboCop: The Soundtrack\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Main Theme\", \"length_seconds\": 180}, {\"title\": \"Murphy's Death\", \"length_seconds\": 240}, {\"title\": \"RoboCop's Training\", \"length_seconds\": 210}, {\"title\": \"ED-209\", \"length_seconds\": 195}, {\"title\": \"Clarence Boddicker\", \"length_seconds\": 220}, {\"title\": \"RoboCop Saves the Day\", \"length_seconds\": 240}, {\"title\": \"RoboCop's Theme\", \"length_seconds\": 180}]}\n",
      "{\"name\": \"Rocky V\", \"artist\": \"Various Artists\", \"songs\": [{\"title\": \"Measure of a Man\", \"length_seconds\": 240}, {\"title\": \"Can't Stop the Fire\", \"length_seconds\": 210}, {\"title\": \"Go for It!\", \"length_seconds\": 180}, {\"title\": \"Take You Back (Home Sweet Home)\", \"length_seconds\": 200}, {\"title\": \"The Measure of a Man (Reprise)\", \"length_seconds\": 120}]}\n"
     ]
    }
   ],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "\n",
    "for movie_name in tqdm(movie_names):\n",
    "    output = program(movie_name=movie_name)\n",
    "    print(output.json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wrote 14 examples to mock_finetune_songs.jsonl\n"
     ]
    }
   ],
   "source": [
    "finetuning_handler.save_finetuning_events(\"mock_finetune_songs.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat mock_finetune_songs.jsonl"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fine-tune on the Dataset\n",
    "\n",
    "We now define a fine-tuning engine and fine-tune on the mock dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import OpenAIFinetuneEngine\n",
    "\n",
    "finetune_engine = OpenAIFinetuneEngine(\n",
    "    \"gpt-3.5-turbo\",\n",
    "    \"mock_finetune_songs.jsonl\",\n",
    "    # start_job_id=\"<start-job-id>\"  # if you have an existing job, can specify id here\n",
    "    validate_json=False,  # openai validate json code doesn't support function calling yet\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<FineTuningJob fine_tuning.job id=ftjob-uJ9kQ9pI0p0YNatBDxF3VITv at 0x172a5c9a0> JSON: {\n",
       "  \"object\": \"fine_tuning.job\",\n",
       "  \"id\": \"ftjob-uJ9kQ9pI0p0YNatBDxF3VITv\",\n",
       "  \"model\": \"gpt-3.5-turbo-0613\",\n",
       "  \"created_at\": 1696463378,\n",
       "  \"finished_at\": 1696463749,\n",
       "  \"fine_tuned_model\": \"ft:gpt-3.5-turbo-0613:llamaindex::8660TXqx\",\n",
       "  \"organization_id\": \"org-1ZDAvajC6v2ZtAP9hLEIsXRz\",\n",
       "  \"result_files\": [\n",
       "    \"file-Hbpw15BAwyf3e4HK5Z9g4IK2\"\n",
       "  ],\n",
       "  \"status\": \"succeeded\",\n",
       "  \"validation_file\": null,\n",
       "  \"training_file\": \"file-MNh7snhv0triDIhsrErokSMY\",\n",
       "  \"hyperparameters\": {\n",
       "    \"n_epochs\": 7\n",
       "  },\n",
       "  \"trained_tokens\": 22834,\n",
       "  \"error\": null\n",
       "}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetune_engine.get_current_job()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Try it Out!\n",
    "\n",
    "We obtain the fine-tuned LLM and use it with the Pydantic program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_llm = finetune_engine.get_finetuned_model(temperature=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_program = OpenAIPydanticProgram.from_defaults(\n",
    "    output_cls=Album,\n",
    "    prompt_template_str=prompt_template_str,\n",
    "    llm=ft_llm,\n",
    "    verbose=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Album(name='Goodfellas Soundtrack', artist='Various Artists', songs=[Song(title='Rags to Riches', length_seconds=180), Song(title='Gimme Shelter', length_seconds=270), Song(title='Layla', length_seconds=270), Song(title='Jump into the Fire', length_seconds=240), Song(title='Atlantis', length_seconds=180), Song(title='Beyond the Sea', length_seconds=180), Song(title='Sunshine of Your Love', length_seconds=240), Song(title='Mannish Boy', length_seconds=240), Song(title='Layla (Piano Exit)', length_seconds=120)])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ft_program(movie_name=\"Goodfellas\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-tuning Structured Outputs through a RAG System\n",
    "\n",
    "A use case of function calling is to get structured outputs through a RAG system.\n",
    "\n",
    "Here we show how to create a training dataset of context-augmented inputs + structured outputs over an unstructured document. We can then fine-tune the LLM and plug it into a RAG system to perform retrieval + output extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-10-04 23:46:36--  https://arxiv.org/pdf/2307.09288.pdf\n",
      "Resolving arxiv.org (arxiv.org)... 128.84.21.199\n",
      "Connecting to arxiv.org (arxiv.org)|128.84.21.199|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 13661300 (13M) [application/pdf]\n",
      "Saving to: ‘data/llama2.pdf’\n",
      "\n",
      "data/llama2.pdf     100%[===================>]  13.03M   229KB/s    in 45s     \n",
      "\n",
      "2023-10-04 23:47:25 (298 KB/s) - ‘data/llama2.pdf’ saved [13661300/13661300]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!mkdir -p data && wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import Field\n",
    "from typing import List\n",
    "\n",
    "\n",
    "class Citation(BaseModel):\n",
    "    \"\"\"Citation class.\"\"\"\n",
    "\n",
    "    author: str = Field(\n",
    "        ..., description=\"Inferred first author (usually last name)\"\n",
    "    )\n",
    "    year: int = Field(..., description=\"Inferred year\")\n",
    "    desc: str = Field(\n",
    "        ...,\n",
    "        description=(\n",
    "            \"Inferred description from the text of the work that the author is\"\n",
    "            \" cited for\"\n",
    "        ),\n",
    "    )\n",
    "\n",
    "\n",
    "class Response(BaseModel):\n",
    "    \"\"\"List of author citations.\n",
    "\n",
    "    Extracted over unstructured text.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    citations: List[Citation] = Field(\n",
    "        ...,\n",
    "        description=(\n",
    "            \"List of author citations (organized by author, year, and\"\n",
    "            \" description).\"\n",
    "        ),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data + Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.readers.file import PyMuPDFReader\n",
    "from llama_index.core import Document\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PyMuPDFReader()\n",
    "docs0 = loader.load(file_path=Path(\"./data/llama2.pdf\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "metadata = {\n",
    "    \"paper_title\": \"Llama 2: Open Foundation and Fine-Tuned Chat Models\"\n",
    "}\n",
    "docs = [Document(text=doc_text, metadata=metadata)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chunk_size = 1024\n",
    "node_parser = SentenceSplitter(chunk_size=chunk_size)\n",
    "nodes = node_parser.get_nodes_from_documents(docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "89"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import Settings\n",
    "\n",
    "finetuning_handler = OpenAIFineTuningHandler()\n",
    "callback_manager = CallbackManager([finetuning_handler])\n",
    "\n",
    "Settings.chunk_size = chunk_size\n",
    "\n",
    "gpt_4_llm = OpenAI(\n",
    "    model=\"gpt-4-0613\", temperature=0.3, callback_manager=callback_manager\n",
    ")\n",
    "\n",
    "gpt_35_llm = OpenAI(\n",
    "    model=\"gpt-3.5-turbo-0613\",\n",
    "    temperature=0.3,\n",
    "    callback_manager=callback_manager,\n",
    ")\n",
    "\n",
    "eval_llm = OpenAI(model=\"gpt-4-0613\", temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Dataset\n",
    "\n",
    "Here we show how to generate a training dataset over these unstructured chunks/nodes.\n",
    "\n",
    "We generate questions that ask about citations in different contexts. We run these questions through a GPT-4 RAG pipeline, extract structured outputs, and log inputs/outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up the dataset generator\n",
    "from llama_index.core.evaluation import DatasetGenerator\n",
    "from llama_index.core import SummaryIndex\n",
    "from llama_index.core import PromptTemplate\n",
    "from tqdm.notebook import tqdm\n",
    "from tqdm.asyncio import tqdm_asyncio\n",
    "\n",
    "\n",
    "question_gen_prompt = PromptTemplate(\n",
    "    \"\"\"\n",
    "{query_str}\n",
    "\n",
    "Context:\n",
    "{context_str}\n",
    "\n",
    "Questions:\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "question_gen_query = \"\"\"\\\n",
    "Snippets from a research paper are given below. They contain citations.\n",
    "Please generate questions from the text asking about these citations.\n",
    "\n",
    "For instance, here are some sample questions:\n",
    "Which citations correspond to related works on transformer models? \n",
    "Tell me about authors that worked on advancing RLHF.\n",
    "Can you tell me citations corresponding to all computer vision works? \\\n",
    "\"\"\"\n",
    "\n",
    "qr_pairs = []\n",
    "node_questions_tasks = []\n",
    "for idx, node in enumerate(nodes[:39]):\n",
    "    num_questions = 1  # increase this to generate more questions per node\n",
    "    dataset_generator = DatasetGenerator(\n",
    "        [node],\n",
    "        question_gen_query=question_gen_query,\n",
    "        text_question_template=question_gen_prompt,\n",
    "        llm=eval_llm,\n",
    "        metadata_mode=\"all\",\n",
    "        num_questions_per_chunk=num_questions,\n",
    "    )\n",
    "\n",
    "    task = dataset_generator.agenerate_questions_from_nodes(num=num_questions)\n",
    "    node_questions_tasks.append(task)\n",
    "node_questions_lists = await tqdm_asyncio.gather(*node_questions_tasks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node_questions_lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "gpt4_index = VectorStoreIndex(nodes=nodes)\n",
    "gpt4_query_engine = gpt4_index.as_query_engine(\n",
    "    output_cls=Response, similarity_top_k=1, llm=gpt_4_llm\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.007736921310424805,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": 15,
       "postfix": null,
       "prefix": "",
       "rate": null,
       "total": 39,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4d406c5c7144773a6cc9698e30b9828",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/39 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error for question Which citations are referred to in the discussion about safety investigations into pretraining data and pretrained models?, ValidationError(model='Response', errors=[{'loc': ('__root__',), 'msg': 'Expecting value: line 1 column 1 (char 0)', 'type': 'value_error.jsondecode', 'ctx': {'msg': 'Expecting value', 'doc': 'Empty Response', 'pos': 0, 'lineno': 1, 'colno': 1}}])\n"
     ]
    }
   ],
   "source": [
    "for idx, node in enumerate(tqdm(nodes[:39])):\n",
    "    node_questions_0 = node_questions_lists[idx]\n",
    "    for question in node_questions_0:\n",
    "        try:\n",
    "            # note: the response itself is unused; events are logged through the fine-tuning handler\n",
    "            gpt4_query_engine.query(question)\n",
    "        except Exception as e:\n",
    "            # skip questions whose structured output fails to parse\n",
    "            print(f\"Error for question {question}, {repr(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wrote 83 examples to llama2_citation_events.jsonl\n"
     ]
    }
   ],
   "source": [
    "finetuning_handler.save_finetuning_events(\"llama2_citation_events.jsonl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup Fine-tuning\n",
    "\n",
    "We kick off fine-tuning over the generated dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.finetuning import OpenAIFinetuneEngine\n",
    "\n",
    "finetune_engine = OpenAIFinetuneEngine(\n",
    "    \"gpt-3.5-turbo\",\n",
    "    \"llama2_citation_events.jsonl\",\n",
    "    # start_job_id=\"<start-job-id>\"  # if you have an existing job, can specify id here\n",
    "    validate_json=False,  # openai validate json code doesn't support function calling yet\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_engine.finetune()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<FineTuningJob fine_tuning.job id=ftjob-ATYm4yZHP1QvXs1wx85Ix79F at 0x1752b6b60> JSON: {\n",
       "  \"object\": \"fine_tuning.job\",\n",
       "  \"id\": \"ftjob-ATYm4yZHP1QvXs1wx85Ix79F\",\n",
       "  \"model\": \"gpt-3.5-turbo-0613\",\n",
       "  \"created_at\": 1696497663,\n",
       "  \"finished_at\": 1696498092,\n",
       "  \"fine_tuned_model\": \"ft:gpt-3.5-turbo-0613:llamaindex::86EwPw83\",\n",
       "  \"organization_id\": \"org-1ZDAvajC6v2ZtAP9hLEIsXRz\",\n",
       "  \"result_files\": [\n",
       "    \"file-wabcIIxjLqvhqOVohf4qSmE7\"\n",
       "  ],\n",
       "  \"status\": \"succeeded\",\n",
       "  \"validation_file\": null,\n",
       "  \"training_file\": \"file-WbYcsinIbH8vyCAstcoFEr92\",\n",
       "  \"hyperparameters\": {\n",
       "    \"n_epochs\": 3\n",
       "  },\n",
       "  \"trained_tokens\": 132678,\n",
       "  \"error\": null\n",
       "}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finetune_engine.get_current_job()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use within RAG Pipeline\n",
    "\n",
    "Let's plug the fine-tuned LLM into a full RAG pipeline that produces structured outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ft_llm = finetune_engine.get_finetuned_model(temperature=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "vector_index = VectorStoreIndex(nodes=nodes)\n",
    "query_engine = vector_index.as_query_engine(\n",
    "    output_cls=Response, similarity_top_k=1, llm=ft_llm\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up a baseline query engine (non-fine-tuned gpt-3.5-turbo) for comparison\n",
    "base_index = VectorStoreIndex(nodes=nodes)\n",
    "base_query_engine = base_index.as_query_engine(\n",
    "    output_cls=Response, similarity_top_k=1, llm=gpt_35_llm\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"citations\": [{\"author\": \"Lin et al.\", \"year\": 2021, \"desc\": \"TruthfulQA, used for LLM hallucinations to measure whether a language model is truthful in generating answers to questions while being informative at the same time.\"}]}\n"
     ]
    }
   ],
   "source": [
    "query_str = \"\"\"\\\n",
    "Which citation is used to measure the truthfulness of Llama 2? \\\n",
    "\"\"\"\n",
    "# query_str = \"\"\"\\\n",
    "# Which citation corresponds to the concept of collecting data that represents \\\n",
    "# empirically sampled human preferences in RLHF?\\\n",
    "# \"\"\"\n",
    "# query_str = \"Which citations in the paper discuss the development and release of Llama 2?\"\n",
    "# query_str = \"Which citations are mentioned in the section on RLHF Results?\"\n",
    "# query_str = \"Which citation discusses the carbon output related to the production of AI hardware?\"\n",
    "\n",
    "\n",
    "response = query_engine.query(query_str)\n",
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"citations\": [{\"author\": \"Lin et al.\", \"year\": 2021, \"desc\": \"TruthfulQA\"}]}\n"
     ]
    }
   ],
   "source": [
    "base_response = base_query_engine.query(query_str)\n",
    "print(str(base_response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# view sources\n",
    "print(response.source_nodes[0].get_content())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"citations\": [{\"author\": \"Lin et al.\", \"year\": 2021, \"desc\": \"TruthfulQA, used for LLM hallucinations to measure whether a language model is truthful in generating answers to questions while being informative at the same time.\"}]}\n"
     ]
    }
   ],
   "source": [
    "# as a reference, take a look at GPT-4 response\n",
    "gpt4_response = gpt4_query_engine.query(query_str)\n",
    "print(str(gpt4_response))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
